
Nanz/overload factor logging#4110

Draft
nanz-nv wants to merge 4 commits into NVIDIA:dev from nanz-nv:nanz/overload_factor_logging

Conversation


@nanz-nv nanz-nv commented Apr 2, 2026

What does this PR do ?

This PR introduces a utility to log overload factor through log_overload_factor.

MoE overload factor

Overload factor is the ratio of the token count on the most loaded rank in a TP-EP group to the balanced token count per rank—the count each rank would see with perfectly balanced routing. It measures workload imbalance.

moe/avg_overload_factor

Arithmetic mean of the overload factor over all layer × microbatch × DP-slice entries recorded on this rank in the step.

moe/max_overload_factor

Maximum of the overload factor over all layer × microbatch × DP-slice entries. Useful for estimating peak buffer size for intermediate activations in the forward path.

moe/max_cum_overload_factor

Ratio of the peak cumulative actual token count to the peak cumulative balanced count over the interleaved fwd/bwd sequence. Useful for estimating how much activation-related memory may need to be retained through backward.
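As an illustrative sketch (plain Python with hypothetical helper names, not the PR's actual implementation), the three metrics can be derived from per-rank token counts roughly like this:

```python
def overload_factors(entries):
    """Per-entry overload factor: tokens on the most loaded rank divided by
    the balanced per-rank count (total tokens / number of ranks)."""
    factors = []
    for counts in entries:  # one list of per-rank counts per layer x microbatch entry
        balanced = sum(counts) / len(counts)
        factors.append(max(counts) / balanced)
    # analogues of moe/avg_overload_factor and moe/max_overload_factor
    return sum(factors) / len(factors), max(factors)


def max_cum_overload(actual_deltas, balanced_deltas):
    """Ratio of peak cumulative actual tokens to peak cumulative balanced
    count; deltas are signed (fwd adds tokens, bwd releases them)."""
    cum_a = cum_b = peak_a = peak_b = 0.0
    for a, b in zip(actual_deltas, balanced_deltas):
        cum_a += a
        cum_b += b
        peak_a = max(peak_a, cum_a)
        peak_b = max(peak_b, cum_b)
    return peak_a / peak_b


# One balanced entry and one where a single rank gets 2x its fair share:
avg, peak = overload_factors([[8, 8, 8, 8], [16, 8, 4, 4]])  # avg=1.5, peak=2.0
cum = max_cum_overload([16, -16, 8], [8, -8, 8])             # 16/8 = 2.0
```

A factor of 1.0 means perfectly balanced routing; 2.0 means the busiest rank saw twice its fair share of tokens.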

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.


copy-pr-bot bot commented Apr 2, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


@Victarry left a comment


Hi @nanz-nv, thanks for this PR, it would be very useful for tracing MoE routing information.

The comments below are a mix from me and Claude Code; please take a look.

```python
)
if (
    hasattr(self, "_inference_token_dispatcher")
    and self.is_inference_cuda_graphed_iteration
```

[SUGGESTION] The log_overload_factor block inside experts_compute_dispatch (balanced count retrieval, dispatcher type check, tensor construction, hook registration) should be extracted into a private method such as _record_overload_factor(self, dispatched_input, tokens_per_expert).

Mixing this logic into experts_compute_dispatch hurts readability. A single call site keeps the dispatch method focused on dispatch logic.
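A structural sketch of the suggested extraction (hypothetical shapes and a stand-in dispatcher, not the PR's code):

```python
class MoEExpertsLayer:
    """Sketch only: shows the single-call-site shape, not real MoE logic."""

    def experts_compute_dispatch(self, hidden_states):
        dispatched_input, tokens_per_expert = self._dispatch(hidden_states)
        # One call keeps this method focused on dispatch logic.
        self._record_overload_factor(dispatched_input, tokens_per_expert)
        return dispatched_input, tokens_per_expert

    def _dispatch(self, hidden_states):
        # Stand-in for the real token dispatcher.
        return hidden_states, [len(hidden_states)]

    def _record_overload_factor(self, dispatched_input, tokens_per_expert):
        # Balanced-count retrieval, dispatcher type checks, tensor
        # construction, and hook registration would all live here.
        self.recorded_tokens = sum(tokens_per_expert)


layer = MoEExpertsLayer()
layer.experts_compute_dispatch([0, 1, 2])  # layer.recorded_tokens == 3
```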

```python
if rm is not None:
    return rm
cm = getattr(td, "_comm_manager", None)
if cm is not None:
```

[SUGGESTION] Avoid abbreviated variable names in parallel/distributed code where clarity is critical:

  • td → token_dispatcher
  • rm → routing_map
  • cm → comm_manager

Same applies to td (L501) and ws (L495) in experts_compute_dispatch. Use tp_ep_world_size for ws.

```python
    Flex/HybridEP keep the map on ``_comm_manager``.
    """
    td = self.token_dispatcher
    rm = getattr(td, "routing_map", None)
```

[SUGGESTION] The balanced token count (routing_map.shape[0] * topk) could be computed earlier in MoELayer.forward() directly from hidden_states before any dispatch, rather than reading routing_map in this post-token_dispatch window.

Current approach has two fragilities:

  1. Requires routing_map to still be alive after token_dispatch but before dispatch_postprocess clears it — a timing assumption tied to dispatcher internals.
  2. Needs separate handling for AllGather (routing_map attr) vs Flex/HybridEP (_comm_manager.routing_map) dispatchers — coupling to internal implementation details.

Computing from hidden_states.shape[0] in MoELayer.forward() would remove _routing_map_after_token_dispatch entirely.
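For illustration, the suggested pre-dispatch computation is just arithmetic on shapes (hypothetical helper name; the real change would live in MoELayer.forward()):

```python
def balanced_token_count(num_local_tokens: int, topk: int) -> int:
    """Every token is routed to topk experts, so the balanced (perfectly
    routed) token count is known before dispatch, from shapes alone."""
    return num_local_tokens * topk


# e.g. hidden_states with 1024 tokens under top-2 routing
count = balanced_token_count(1024, 2)  # 2048
```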

```python
local_balanced = torch.empty(
    (), device=dispatched_input.device, dtype=torch.float32
)
local_balanced.fill_(base)
```

[SUGGESTION] torch.empty(()) + fill_() is unnecessarily verbose. Use torch.tensor directly:

```python
local_balanced = torch.tensor(base, device=dispatched_input.device, dtype=torch.float32)
```

```python
tokens_on_rank = tokens_per_expert.detach().sum()
if not tokens_on_rank.is_floating_point():
    tokens_on_rank = tokens_on_rank.float()
tokens_on_rank = tokens_on_rank.to(device=tensor.device, dtype=torch.float32).reshape(())
```

[SUGGESTION] .reshape(()) is a no-op here — tokens_per_expert.detach().sum() already returns a 0-dim tensor. Same applies to balanced on L1016. Both .reshape(()) calls can be removed.

```python
    device=device,
    dtype=torch.float32,
)
torch.distributed.all_reduce(
```

[SUGGESTION] report() contains 6–7 sequential all_reduce calls across the same groups. Some can be fused to reduce collective launch overhead:

  • tp_ep: all_reduce(max_actual, MAX) (L283) and all_reduce(balanced_stacked, SUM) (L295) can be packed into one call by stacking both tensors.
  • dp: all_reduce(overload_avg, AVG) (L303) and all_reduce(overload_max, MAX) (L307) operate on the same data — consider a single fused reduce.
  • ratio_t: two sequential all_reduces (tp_ep MAX L268, dp MAX L272) could be deferred and folded into the tp_ep and dp passes above.

At scale (hundreds of MoE layers, large DP), reducing the number of collectives per logging step is meaningful.

```python
if self._pending_clear:
    self._pending_clear = False
    self._clear_storage()
```

[SUGGESTION] The _pending_clear deferred-clear mechanism adds complexity that does not deliver its intended benefit.

The rationale (from the class docstring) is to keep tensor handles valid during CUDA graph replay windows. However, CUDA graph replay does not re-execute Python-side autograd functions — record_fwd and record_bwd are never called during replay. So the tracker never receives new data during replay regardless of whether storage has been cleared, making the deferred-clear protection moot.

For non-CUDA-graph training the extra state (_pending_clear flag + _flush_pending_clear() call in every record_fwd/record_bwd) is pure overhead with no benefit.

Suggestion: Have clear() call _clear_storage() directly and remove _pending_clear and _flush_pending_clear().
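A minimal sketch of the simplified lifecycle (pure Python, hypothetical class shape, not the PR's tracker):

```python
class OverloadFactorTracker:
    """Sketch: clear() resets storage immediately; no _pending_clear flag
    and no _flush_pending_clear() call on every record."""

    def __init__(self):
        self._fwd_records = []

    def record_fwd(self, tokens_on_rank):
        self._fwd_records.append(tokens_on_rank)

    def clear(self):
        # Direct reset: CUDA graph replay never re-enters Python-side
        # autograd hooks, so no recording can observe a mid-clear state.
        self._fwd_records = []


tracker = OverloadFactorTracker()
tracker.record_fwd(16)
tracker.clear()  # tracker._fwd_records == []
```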

```python
    g = parallel_state.get_pipeline_model_parallel_group(check_initialized=False)
    return g

def report(
```

[SUGGESTION] report() is ~200 lines and mixes several independent responsibilities: tp_ep reduction, dp reduction, cumsum peak computation, pp reduction, TensorBoard/W&B logging, and log-string assembly. This makes it hard to test, read, and debug any single concern.

Suggest splitting into focused private methods:

  • _reduce_tp_ep(fwd_tensors, balanced_tensors) → returns tp_ep_overload
  • _reduce_dp(tp_ep_overload) → returns (overload_avg, overload_max)
  • _compute_max_cum_overload() → returns max_cum_overload_factor
  • _reduce_pp(max_overload, max_cum) → returns pp-reduced scalars
  • _log_to_writers(writer, wandb_writer, scalars, iteration) → writes TB/W&B

report() then becomes a thin orchestrator responsible only for call ordering.

Comment on lines +96 to +119
```python
"""Tracker for MoE overload factor metrics.

Records per-layer **tokens on this rank** after dispatch
(``tokens_per_expert.sum()``) and a pre-dispatch **balanced token count** scalar
(from ``routing_map.shape[0] * moe_router_topk`` read after ``token_dispatch``),
via an autograd hook on ``dispatched_input``. ``report()`` does
``all_reduce(MAX)`` on per-rank actual totals over ``tp_ep_group``, divides by
balanced count per rank (from summed local counts / size) to get **tp_ep
overload** per microbatch entry, then ``all_reduce(AVG)`` and ``all_reduce(MAX)``
on that overload across ``dp_group`` before scalar summaries.
Over the **pipeline-parallel** group, ``max`` and ``max_cum`` scalars are
``all_reduce(MAX)`` so every stage agrees on the worst overload; ranks without
MoE layers contribute ``0``. The **mean** overload scalar is **not** reduced
across PP (each rank logs its local mean, ``0`` if it recorded nothing).
``_fwd_bwd`` / ``_fwd_bwd_balanced`` mirror interleaved fwd/bwd so cumulative
peaks of actual vs balanced token counts can be compared.

Lifecycle: set_process_groups() and record_fwd/record_bwd during forward
(SaveOverloadFactorFunction in MoELayer) → report() at step end
(sync, aggregate, log, deferred clear) → repeat.

``clear()`` only marks storage for reset on the next ``record_fwd`` or
``record_bwd`` so tensor handles stay valid until Python runs a recording
hook again (e.g. across CUDA graph replay windows that skip those hooks).
```

Comments could be refactored with AI for better readability

Example comment refactored by GPT

```python
"""Track MoE overload-factor metrics.

Recorded values
---------------
- Per-layer actual tokens on this rank after dispatch:
  tokens_per_expert.sum().
- Per-layer balanced token count before dispatch:
  routing_map.shape[0] * moe_router_topk (read after token_dispatch).
- Both values are captured by an autograd hook on dispatched_input.
- _fwd_bwd and _fwd_bwd_balanced mirror interleaved fwd/bwd events so
  cumulative peaks of actual vs balanced token counts can be compared.

How report() aggregates
-----------------------
1. In tp_ep_group, run all_reduce(MAX) on per-rank actual totals.
2. Divide by balanced tokens per rank (summed local balanced counts / group
   size) to get per-entry tp_ep overload.
3. In dp_group, run all_reduce(AVG) and all_reduce(MAX) on overload
   before scalar summaries.
4. In the pipeline-parallel group, max and max_cum use all_reduce(MAX) so
   every stage agrees on the worst overload. Ranks without MoE layers
   contribute 0.
5. Mean overload is not reduced across PP; each rank logs its local mean
   (0 if nothing was recorded).

Lifecycle
---------
set_process_groups() and record_fwd()/record_bwd() are called during forward
(SaveOverloadFactorFunction in MoELayer). report() runs at step end
(sync, aggregate, log, deferred clear), then the cycle repeats.

clear() behavior
----------------
clear() does not immediately reset storage. It marks storage for reset on
the next record_fwd() or record_bwd() so tensor handles stay valid until
Python executes a recording hook again (for example across CUDA graph replay
windows that skip those hooks).
```

```python
    return grad_output, None, None, None


def save_overload_factor_to_tracker(
```

[SUGGESTION] save_overload_factor_to_tracker and SaveOverloadFactorFunction are misleadingly named — neither computes nor saves an overload factor. What they actually do is record post-dispatch token counts (actual tokens on this rank + balanced token count) into the tracker. The overload factor itself is only computed later in report().

Suggested renames:

  • SaveOverloadFactorFunction → RecordDispatchTokenCountsFunction
  • save_overload_factor_to_tracker → record_dispatch_token_counts
